Add diffusion model implementation #408
base: dev
Preliminary implementation, to be extended with other variants as well.
Thanks @vpratz for the implementation! I plan to add additional schedules and samplers on top of this by the end of the week.
Thanks for taking a look. Do you know whether your implementation would benefit from the pre-conditioning discussed in "Elucidating the Design Space of Diffusion-Based Generative Models", and whether we can combine them in one joint framework?
Part of the pre-conditioning can be expressed as a special kind of weighting function: see appendix D.1 in here. So yes, the aim would be to have one nice framework!
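To make the weighting view concrete, here is a minimal NumPy sketch using the formulas from the EDM paper (Karras et al., arXiv:2206.00364): the loss weight is λ(σ) = (σ² + σ_data²) / (σ · σ_data)², which is exactly 1 / c_out(σ)². The function names are hypothetical and not taken from this PR, and `sigma_data=0.5` is only an illustrative value.

```python
import numpy as np


def edm_c_out(sigma, sigma_data=0.5):
    """Output scaling c_out(sigma) of the EDM preconditioning."""
    return sigma * sigma_data / np.sqrt(sigma**2 + sigma_data**2)


def edm_loss_weight(sigma, sigma_data=0.5):
    """EDM loss weight: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


# Illustrative noise levels (assumed values, not from the PR).
sigma = np.array([0.01, 0.1, 1.0, 10.0])
weights = edm_loss_weight(sigma)

# The weight cancels the output scaling, so the effective loss on the raw
# network output has unit weight at every noise level.
effective_weight = weights * edm_c_out(sigma) ** 2
```

With this choice, `effective_weight` is one at every noise level, which is one way to read "pre-conditioning as a weighting function".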
I added some more noise schedules and started to make the implementation more general. This is just a first draft, so that you, @vpratz, get an idea of how we could do it. We should then discuss this and how to move forward.
I added a class Next step would be to add stochastic samplers as well.
Thanks a lot for the fixes, they increase the performance of the new implementation a lot. The old standalone EDM implementation seems to be a little bit better still, but the difference might be down to hyperparameter tuning. I have added the As far as I can tell, the open steps before we finalize this PR are:
Did I miss anything, @arrjon, or do you have any other comments on the current state?
The performance issue is fixed now. It was mainly due to a missing scaling factor of the log_snr, which goes into the network. I am also implementing the stochastic sampler: it is working for all backends but JAX at the moment. After this, only the things @vpratz mentioned are missing.
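The thread does not say which scaling factor was missing, so the following is only an illustration of why such a factor matters. It assumes a variance-exploding parameterization (signal scale 1, so SNR = 1/σ²) and uses the noise conditioning c_noise = ln(σ)/4 from the EDM paper; in terms of the log-SNR this is a rescaling by -1/8. All function names are hypothetical.

```python
import numpy as np


def log_snr_from_sigma(sigma):
    """Log signal-to-noise ratio, assuming signal scale 1: SNR = 1 / sigma^2."""
    return -2.0 * np.log(sigma)


def edm_c_noise(sigma):
    """Noise conditioning from the EDM paper: c_noise = ln(sigma) / 4."""
    return 0.25 * np.log(sigma)


def c_noise_from_log_snr(log_snr):
    """Same conditioning expressed via the log-SNR: c_noise = -log_snr / 8."""
    return -log_snr / 8.0


# Illustrative noise levels (assumed values, not from the PR).
sigma = np.array([0.002, 0.5, 80.0])
log_snr = log_snr_from_sigma(sigma)

# The raw log-SNR spans a much wider range than the rescaled conditioning.
raw_range = log_snr.max() - log_snr.min()
scaled = c_noise_from_log_snr(log_snr)
scaled_range = scaled.max() - scaled.min()
```

Feeding the unscaled log-SNR to the network gives it an input range eight times wider than the rescaled version, which can plausibly hurt training.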
The stochastic sampler is now also working for JAX. So all features are done for the moment!
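For readers unfamiliar with stochastic samplers, here is a generic Euler-Maruyama sketch of a reverse-time SDE in NumPy. It is not the sampler implemented in this PR. It assumes a variance-exploding process with σ(t) = t and replaces the trained network with the exact score of a one-dimensional Gaussian toy target, so the result can be checked. All names and default values are hypothetical.

```python
import numpy as np


def gaussian_score(x, sigma, data_std):
    """Exact score of the noised toy target N(0, data_std^2 + sigma^2)."""
    return -x / (data_std**2 + sigma**2)


def sample_reverse_sde(score_fn, n_samples, sigma_max=10.0, sigma_min=1e-3,
                       n_steps=500, seed=0):
    """Euler-Maruyama integration of the reverse SDE from sigma_max to sigma_min."""
    rng = np.random.default_rng(seed)
    sigmas = np.linspace(sigma_max, sigma_min, n_steps + 1)
    # Start from the (approximate) prior N(0, sigma_max^2).
    x = sigma_max * rng.standard_normal(n_samples)
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        dt = sigma - sigma_next   # positive step size, moving backward in time
        g2 = 2.0 * sigma          # g(t)^2 = d(sigma^2)/dt for sigma(t) = t
        noise = rng.standard_normal(n_samples)
        # Drift pulls samples along the score; fresh noise is injected each step.
        x = x + g2 * score_fn(x, sigma) * dt + np.sqrt(g2 * dt) * noise
    return x


data_std = 2.0
samples = sample_reverse_sde(
    lambda x, s: gaussian_score(x, s, data_std), n_samples=20000
)
```

The samples should have roughly zero mean and a standard deviation close to `data_std`. A deterministic sampler would drop the noise term and halve the drift (the probability-flow ODE).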
Great! Thanks a lot for putting in the work and for the quick fixes! I'll try to add the relevant tests and work on some of the other missing things in the next few days.
This PR adds a diffusion model implementation for use as an inference network, as discussed in #403. It implements the design introduced as "EDM" in [1]. The overall structure is taken from the FlowMatching class.

@arrjon @niels-leif-bracher I would appreciate it if you could take a look and make suggestions regarding how we can incorporate the other diffusion model variants as well. For now, I chose to expose only the sigma_data parameter to the end user and keep everything else private. This should enable us to change the internals later on and incrementally add new functionality.

Please let me know how we want to proceed and how much capacity you have to move this forward, so that we can decide whether we want to include the additional options before we merge, or whether we merge early and then add to it incrementally. I have situated the class in the experimental module for now, so that we have some freedom to change things in the future as we see fit.

[1] https://arxiv.org/abs/2206.00364
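As background on what sigma_data controls, here is a minimal NumPy sketch of the EDM preconditioning from [1], where the denoiser is D(x; σ) = c_skip · x + c_out · F(c_in · x, c_noise). It is not the code in this PR. The function names and the placeholder network are hypothetical, and `sigma_data=0.5` is only an illustrative value.

```python
import numpy as np


def edm_preconditioning(sigma, sigma_data=0.5):
    """Return (c_skip, c_out, c_in, c_noise) as defined in the EDM paper."""
    total_var = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / total_var
    c_out = sigma * sigma_data / np.sqrt(total_var)
    c_in = 1.0 / np.sqrt(total_var)
    c_noise = 0.25 * np.log(sigma)
    return c_skip, c_out, c_in, c_noise


def denoise(network, x_noisy, sigma, sigma_data=0.5):
    """Preconditioned denoiser D(x; sigma) wrapping a raw network F."""
    c_skip, c_out, c_in, c_noise = edm_preconditioning(sigma, sigma_data)
    return c_skip * x_noisy + c_out * network(c_in * x_noisy, c_noise)


def zero_network(x, c_noise):
    """Hypothetical stand-in for a trained network; always predicts zero."""
    return np.zeros_like(x)


x = np.array([1.0, -2.0, 0.5])
# With sigma == sigma_data, c_skip is 0.5, so a zero network halves the input.
out = denoise(zero_network, x, sigma=0.5, sigma_data=0.5)
c_skip, c_out, c_in, c_noise = edm_preconditioning(sigma=0.5, sigma_data=0.5)
```

Here sigma_data is the assumed standard deviation of the data: c_in rescales the noisy input to unit variance, and c_skip tends to 1 as σ tends to 0, so the denoiser passes nearly clean inputs through unchanged.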